#
#  Copyright 2025 The InfiniFlow Authors. All Rights Reserved.
#
#  Licensed under the Apache License, Version 2.0 (the "License");
#  you may not use this file except in compliance with the License.
#  You may obtain a copy of the License at
#
#      http://www.apache.org/licenses/LICENSE-2.0
#
#  Unless required by applicable law or agreed to in writing, software
#  distributed under the License is distributed on an "AS IS" BASIS,
#  WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
#  See the License for the specific language governing permissions and
#  limitations under the License.
#
import json
import logging
import os
import re
import time
from typing import Any, Optional

from elasticsearch_dsl import Q, Search
from pydantic import BaseModel
from pymysql.converters import escape_string
from pyobvector import ObVecClient, FtsIndexParam, FtsParser, ARRAY, VECTOR
from pyobvector.client.hybrid_search import HybridSearch
from pyobvector.util import ObVersion
from sqlalchemy import text, Column, String, Integer, JSON, Double, Row, Table
from sqlalchemy.dialects.mysql import LONGTEXT, TEXT
from sqlalchemy.sql.type_api import TypeEngine

from common import settings
from common.constants import PAGERANK_FLD, TAG_FLD
from common.decorator import singleton
from common.float_utils import get_float
from rag.nlp import rag_tokenizer
from rag.utils.doc_store_conn import DocStoreConnection, MatchExpr, OrderByExpr, FusionExpr, MatchTextExpr, \
    MatchDenseExpr

ATTEMPT_TIME = 2
OB_QUERY_TIMEOUT = int(os.environ.get("OB_QUERY_TIMEOUT", "100_000_000"))

logger = logging.getLogger('ragflow.ob_conn')

column_order_id = Column("_order_id", Integer, nullable=True, comment="chunk order id for maintaining sequence")
column_group_id = Column("group_id", String(256), nullable=True, comment="group id for external retrieval")

column_definitions: list[Column] = [
    Column("id", String(256), primary_key=True, comment="chunk id"),
    Column("kb_id", String(256), nullable=False, index=True, comment="knowledge base id"),
    Column("doc_id", String(256), nullable=True, index=True, comment="document id"),
    Column("docnm_kwd", String(256), nullable=True, comment="document name"),
    Column("doc_type_kwd", String(256), nullable=True, comment="document type"),
    Column("title_tks", String(256), nullable=True, comment="title tokens"),
    Column("title_sm_tks", String(256), nullable=True, comment="fine-grained (small) title tokens"),
    Column("content_with_weight", LONGTEXT, nullable=True, comment="the original content"),
    Column("content_ltks", LONGTEXT, nullable=True, comment="long text tokens derived from content_with_weight"),
    Column("content_sm_ltks", LONGTEXT, nullable=True, comment="fine-grained (small) tokens derived from content_ltks"),
    Column("pagerank_fea", Integer, nullable=True, comment="page rank priority, usually set in kb level"),
    Column("important_kwd", ARRAY(String(256)), nullable=True, comment="keywords"),
    Column("important_tks", TEXT, nullable=True, comment="keyword tokens"),
    Column("question_kwd", ARRAY(String(1024)), nullable=True, comment="questions"),
    Column("question_tks", TEXT, nullable=True, comment="question tokens"),
    Column("tag_kwd", ARRAY(String(256)), nullable=True, comment="tags"),
    Column("tag_feas", JSON, nullable=True,
           comment="tag features used for 'rank_feature', format: [tag -> relevance score]"),
    Column("available_int", Integer, nullable=False, index=True, server_default="1",
           comment="status of availability, 0 for unavailable, 1 for available"),
    Column("create_time", String(19), nullable=True, comment="creation time in YYYY-MM-DD HH:MM:SS format"),
    Column("create_timestamp_flt", Double, nullable=True, comment="creation timestamp in float format"),
    Column("img_id", String(128), nullable=True, comment="image id"),
    Column("position_int", ARRAY(ARRAY(Integer)), nullable=True, comment="position"),
    Column("page_num_int", ARRAY(Integer), nullable=True, comment="page number"),
    Column("top_int", ARRAY(Integer), nullable=True, comment="rank from the top"),
    Column("knowledge_graph_kwd", String(256), nullable=True, index=True, comment="knowledge graph chunk type"),
    Column("source_id", ARRAY(String(256)), nullable=True, comment="source document id"),
    Column("entity_kwd", String(256), nullable=True, comment="entity name"),
    Column("entity_type_kwd", String(256), nullable=True, index=True, comment="entity type"),
    Column("from_entity_kwd", String(256), nullable=True, comment="the source entity of this edge"),
    Column("to_entity_kwd", String(256), nullable=True, comment="the target entity of this edge"),
    Column("weight_int", Integer, nullable=True, comment="the weight of this edge"),
    Column("weight_flt", Double, nullable=True, comment="the weight of community report"),
    Column("entities_kwd", ARRAY(String(256)), nullable=True, comment="node ids of entities"),
    Column("rank_flt", Double, nullable=True, comment="rank of this entity"),
    Column("removed_kwd", String(256), nullable=True, index=True, server_default="'N'",
           comment="whether it has been deleted"),
    Column("metadata", JSON, nullable=True, comment="metadata for this chunk"),
    Column("extra", JSON, nullable=True, comment="extra information of non-general chunk"),
    column_order_id,
    column_group_id,
]

column_names: list[str] = [col.name for col in column_definitions]
column_types: dict[str, TypeEngine] = {col.name: col.type for col in column_definitions}
array_columns: list[str] = [col.name for col in column_definitions if isinstance(col.type, ARRAY)]

vector_column_pattern = re.compile(r"q_(?P<vector_size>\d+)_vec")

index_columns: list[str] = [
    "kb_id",
    "doc_id",
    "available_int",
    "knowledge_graph_kwd",
    "entity_type_kwd",
    "removed_kwd",
]

fulltext_search_columns: list[str] = [
    "docnm_kwd",
    "content_with_weight",
    "title_tks",
    "title_sm_tks",
    "important_tks",
    "question_tks",
    "content_ltks",
    "content_sm_ltks"
]

fts_columns_origin: list[str] = [
    "docnm_kwd^10",
    "content_with_weight",
    "important_tks^20",
    "question_tks^20",
]

fts_columns_tks: list[str] = [
    "title_tks^10",
    "title_sm_tks^5",
    "important_tks^20",
    "question_tks^20",
    "content_ltks^2",
    "content_sm_ltks",
]

index_name_template = "ix_%s_%s"
fulltext_index_name_template = "fts_idx_%s"
# MATCH AGAINST: https://www.oceanbase.com/docs/common-oceanbase-database-cn-1000000002017607
fulltext_search_template = "MATCH (%s) AGAINST ('%s' IN NATURAL LANGUAGE MODE)"
# cosine_distance: https://www.oceanbase.com/docs/common-oceanbase-database-cn-1000000002012938
vector_search_template = "cosine_distance(%s, %s)"


class SearchResult(BaseModel):
    total: int
    chunks: list[dict]


def get_column_value(column_name: str, value: Any) -> Any:
    if column_name in column_types:
        column_type = column_types[column_name]
        if isinstance(column_type, String):
            return str(value)
        elif isinstance(column_type, Integer):
            return int(value)
        elif isinstance(column_type, Double):
            return float(value)
        elif isinstance(column_type, ARRAY) or isinstance(column_type, JSON):
            if isinstance(value, str):
                try:
                    return json.loads(value)
                except json.JSONDecodeError:
                    return value
            else:
                return value
        else:
            raise ValueError(f"Unsupported column type for column '{column_name}': {column_type}")
    elif vector_column_pattern.match(column_name):
        if isinstance(value, str):
            try:
                return json.loads(value)
            except json.JSONDecodeError:
                return value
        else:
            return value
    elif column_name == "_score":
        return float(value)
    else:
        raise ValueError(f"Unknown column '{column_name}' with value '{value}'.")


def get_default_value(column_name: str) -> Any:
    if column_name == "available_int":
        return 1
    elif column_name == "removed_kwd":
        return "N"
    elif column_name == "_order_id":
        return 0
    else:
        return None


def get_value_str(value: Any) -> str:
    if isinstance(value, str):
        cleaned_str = value.replace('\\', '\\\\')
        cleaned_str = cleaned_str.replace('\n', '\\n')
        cleaned_str = cleaned_str.replace('\r', '\\r')
        cleaned_str = cleaned_str.replace('\t', '\\t')
        return f"'{escape_string(cleaned_str)}'"
    elif isinstance(value, bool):
        return "true" if value else "false"
    elif value is None:
        return "NULL"
    elif isinstance(value, (list, dict)):
        json_str = json.dumps(value, ensure_ascii=False)
        return f"'{escape_string(json_str)}'"
    else:
        return str(value)


def get_metadata_filter_expression(metadata_filtering_conditions: dict) -> str:
    """
    Convert metadata filtering conditions to MySQL JSON path expression.

    Args:
        metadata_filtering_conditions: dict with 'conditions' and 'logical_operator' keys

    Returns:
        MySQL JSON path expression string
    """
    if not metadata_filtering_conditions:
        return ""

    conditions = metadata_filtering_conditions.get("conditions", [])
    logical_operator = metadata_filtering_conditions.get("logical_operator", "and").upper()

    if not conditions:
        return ""

    if logical_operator not in ["AND", "OR"]:
        raise ValueError(f"Unsupported logical operator: {logical_operator}. Only 'and' and 'or' are supported.")

    metadata_filters = []
    for condition in conditions:
        name = condition.get("name")
        comparison_operator = condition.get("comparison_operator")
        value = condition.get("value")

        if not all([name, comparison_operator]):
            continue

        expr = f"JSON_EXTRACT(metadata, '$.{name}')"
        value_str = get_value_str(value) if value else ""

        # Convert comparison operator to MySQL JSON path syntax
        if comparison_operator == "is":
            # JSON_EXTRACT(metadata, '$.field_name') = 'value'
            metadata_filters.append(f"{expr} = {value_str}")
        elif comparison_operator == "is not":
            metadata_filters.append(f"{expr} != {value_str}")
        elif comparison_operator == "contains":
            metadata_filters.append(f"JSON_CONTAINS({expr}, {value_str})")
        elif comparison_operator == "not contains":
            metadata_filters.append(f"NOT JSON_CONTAINS({expr}, {value_str})")
        elif comparison_operator == "start with":
            metadata_filters.append(f"{expr} LIKE CONCAT({value_str}, '%')")
        elif comparison_operator == "end with":
            metadata_filters.append(f"{expr} LIKE CONCAT('%', {value_str})")
        elif comparison_operator == "empty":
            metadata_filters.append(f"({expr} IS NULL OR {expr} = '' OR {expr} = '[]' OR {expr} = '{{}}')")
        elif comparison_operator == "not empty":
            metadata_filters.append(f"({expr} IS NOT NULL AND {expr} != '' AND {expr} != '[]' AND {expr} != '{{}}')")
        # Number operators
        elif comparison_operator == "=":
            metadata_filters.append(f"CAST({expr} AS DECIMAL(20,10)) = {value_str}")
        elif comparison_operator == "≠":
            metadata_filters.append(f"CAST({expr} AS DECIMAL(20,10)) != {value_str}")
        elif comparison_operator == ">":
            metadata_filters.append(f"CAST({expr} AS DECIMAL(20,10)) > {value_str}")
        elif comparison_operator == "<":
            metadata_filters.append(f"CAST({expr} AS DECIMAL(20,10)) < {value_str}")
        elif comparison_operator == "≥":
            metadata_filters.append(f"CAST({expr} AS DECIMAL(20,10)) >= {value_str}")
        elif comparison_operator == "≤":
            metadata_filters.append(f"CAST({expr} AS DECIMAL(20,10)) <= {value_str}")
        # Time operators
        elif comparison_operator == "before":
            metadata_filters.append(f"CAST({expr} AS DATETIME) < {value_str}")
        elif comparison_operator == "after":
            metadata_filters.append(f"CAST({expr} AS DATETIME) > {value_str}")
        else:
            logger.warning(f"Unsupported comparison operator: {comparison_operator}")
            continue

    if not metadata_filters:
        return ""

    return f"({f' {logical_operator} '.join(metadata_filters)})"


def get_filters(condition: dict) -> list[str]:
    filters: list[str] = []
    for k, v in condition.items():
        if not v:
            continue

        if k == "exists":
            filters.append(f"{v} IS NOT NULL")
        elif k == "must_not" and isinstance(v, dict) and "exists" in v:
            filters.append(f"{v.get('exists')} IS NULL")
        elif k == "metadata_filtering_conditions":
            # Handle metadata filtering conditions
            metadata_filter = get_metadata_filter_expression(v)
            if metadata_filter:
                filters.append(metadata_filter)
        elif k in array_columns:
            if isinstance(v, list):
                array_filters = []
                for vv in v:
                    array_filters.append(f"array_contains({k}, {get_value_str(vv)})")
                array_filter = " OR ".join(array_filters)
                filters.append(f"({array_filter})")
            else:
                filters.append(f"array_contains({k}, {get_value_str(v)})")
        elif isinstance(v, list):
            values: list[str] = []
            for item in v:
                values.append(get_value_str(item))
            value = ", ".join(values)
            filters.append(f"{k} IN ({value})")
        else:
            filters.append(f"{k} = {get_value_str(v)}")
    return filters


def _try_with_lock(lock_name: str, process_func, check_func, timeout: int = None):
    if not timeout:
        timeout = int(os.environ.get("OB_DDL_TIMEOUT", "60"))

    if not check_func():
        from rag.utils.redis_conn import RedisDistributedLock
        lock = RedisDistributedLock(lock_name)
        if lock.acquire():
            logger.info(f"acquired lock success: {lock_name}, start processing.")
            try:
                process_func()
                return
            finally:
                lock.release()

    if not check_func():
        logger.info(f"Waiting for process complete for {lock_name} on other task executors.")
        time.sleep(1)
        count = 1
        while count < timeout and not check_func():
            count += 1
            time.sleep(1)
        if count >= timeout and not check_func():
            raise Exception(f"Timeout to wait for process complete for {lock_name}.")


@singleton
class OBConnection(DocStoreConnection):
    def __init__(self):
        scheme: str = settings.OB.get("scheme")
        ob_config = settings.OB.get("config", {})

        if scheme and scheme.lower() == "mysql":
            mysql_config = settings.get_base_config("mysql", {})
            logger.info("Use MySQL scheme to create OceanBase connection.")
            host = mysql_config.get("host", "localhost")
            port = mysql_config.get("port", 2881)
            self.username = mysql_config.get("user", "root@test")
            self.password = mysql_config.get("password", "infini_rag_flow")
        else:
            logger.info("Use customized config to create OceanBase connection.")
            host = ob_config.get("host", "localhost")
            port = ob_config.get("port", 2881)
            self.username = ob_config.get("user", "root@test")
            self.password = ob_config.get("password", "infini_rag_flow")

        self.db_name = ob_config.get("db_name", "test")
        self.uri = f"{host}:{port}"

        logger.info(f"Use OceanBase '{self.uri}' as the doc engine.")

        for _ in range(ATTEMPT_TIME):
            try:
                self.client = ObVecClient(
                    uri=self.uri,
                    user=self.username,
                    password=self.password,
                    db_name=self.db_name,
                    pool_pre_ping=True,
                    pool_recycle=3600,
                )
                break
            except Exception as e:
                logger.warning(f"{str(e)}. Waiting OceanBase {self.uri} to be healthy.")
                time.sleep(5)

        if self.client is None:
            msg = f"OceanBase {self.uri} connection failed after {ATTEMPT_TIME} attempts."
            logger.error(msg)
            raise Exception(msg)

        self._load_env_vars()
        self._check_ob_version()
        self._try_to_update_ob_query_timeout()

        logger.info(f"OceanBase {self.uri} is healthy.")

    def _check_ob_version(self):
        try:
            res = self.client.perform_raw_text_sql("SELECT OB_VERSION() FROM DUAL").fetchone()
            version_str = res[0] if res else None
            logger.info(f"OceanBase {self.uri} version is {version_str}")
        except Exception as e:
            raise Exception(f"Failed to get OceanBase version from {self.uri}, error: {str(e)}")

        if not version_str:
            raise Exception(f"Failed to get OceanBase version from {self.uri}.")

        ob_version = ObVersion.from_db_version_string(version_str)
        if ob_version < ObVersion.from_db_version_nums(4, 3, 5, 1):
            raise Exception(
                f"The version of OceanBase needs to be higher than or equal to 4.3.5.1, current version is {version_str}"
            )

        self.es = None
        if not ob_version < ObVersion.from_db_version_nums(4, 4, 1, 0) and self.enable_hybrid_search:
            self.es = HybridSearch(
                uri=self.uri,
                user=self.username,
                password=self.password,
                db_name=self.db_name,
                pool_pre_ping=True,
                pool_recycle=3600,
            )
            logger.info("OceanBase Hybrid Search feature is enabled")

    def _try_to_update_ob_query_timeout(self):
        try:
            val = self._get_variable_value("ob_query_timeout")
            if val and int(val) >= OB_QUERY_TIMEOUT:
                return
        except Exception as e:
            logger.warning("Failed to get 'ob_query_timeout' variable: %s", str(e))

        try:
            self.client.perform_raw_text_sql(f"SET GLOBAL ob_query_timeout={OB_QUERY_TIMEOUT}")
            logger.info("Set GLOBAL variable 'ob_query_timeout' to %d.", OB_QUERY_TIMEOUT)

            # refresh connection pool to ensure 'ob_query_timeout' has taken effect
            self.client.engine.dispose()
            if self.es is not None:
                self.es.engine.dispose()
            logger.info("Disposed all connections in engine pool to refresh connection pool")
        except Exception as e:
            logger.warning(f"Failed to set 'ob_query_timeout' variable: {str(e)}")

    def _load_env_vars(self):

        def is_true(var: str, default: str) -> bool:
            return os.getenv(var, default).lower() in ['true', '1', 'yes', 'y']

        self.enable_fulltext_search = is_true('ENABLE_FULLTEXT_SEARCH', 'true')
        self.use_fulltext_hint = is_true('USE_FULLTEXT_HINT', 'true')
        self.search_original_content = is_true("SEARCH_ORIGINAL_CONTENT", 'true')
        self.enable_hybrid_search = is_true('ENABLE_HYBRID_SEARCH', 'false')

    """
    Database operations
    """

    def dbType(self) -> str:
        return "oceanbase"

    def health(self) -> dict:
        return {
            "uri": self.uri,
            "version_comment": self._get_variable_value("version_comment")
        }

    def _get_variable_value(self, var_name: str) -> Any:
        rows = self.client.perform_raw_text_sql(f"SHOW VARIABLES LIKE '{var_name}'")
        for row in rows:
            return row[1]
        raise Exception(f"Variable '{var_name}' not found.")

    """
    Table operations
    """

    def createIdx(self, indexName: str, knowledgebaseId: str, vectorSize: int):
        vector_field_name = f"q_{vectorSize}_vec"
        vector_index_name = f"{vector_field_name}_idx"

        try:
            _try_with_lock(
                lock_name=f"ob_create_table_{indexName}",
                check_func=lambda: self.client.check_table_exists(indexName),
                process_func=lambda: self._create_table(indexName),
            )

            for column_name in index_columns:
                _try_with_lock(
                    lock_name=f"ob_add_idx_{indexName}_{column_name}",
                    check_func=lambda: self._index_exists(indexName, index_name_template % (indexName, column_name)),
                    process_func=lambda: self._add_index(indexName, column_name),
                )

            fts_columns = fts_columns_origin if self.search_original_content else fts_columns_tks
            for fts_column in fts_columns:
                column_name = fts_column.split("^")[0]
                _try_with_lock(
                    lock_name=f"ob_add_fulltext_idx_{indexName}_{column_name}",
                    check_func=lambda: self._index_exists(indexName, fulltext_index_name_template % column_name),
                    process_func=lambda: self._add_fulltext_index(indexName, column_name),
                )

            _try_with_lock(
                lock_name=f"ob_add_vector_column_{indexName}_{vector_field_name}",
                check_func=lambda: self._column_exist(indexName, vector_field_name),
                process_func=lambda: self._add_vector_column(indexName, vectorSize),
            )

            _try_with_lock(
                lock_name=f"ob_add_vector_idx_{indexName}_{vector_field_name}",
                check_func=lambda: self._index_exists(indexName, vector_index_name),
                process_func=lambda: self._add_vector_index(indexName, vector_field_name),
            )

            # new columns migration
            for column in [column_order_id, column_group_id]:
                _try_with_lock(
                    lock_name=f"ob_add_{column.name}_{indexName}",
                    check_func=lambda: self._column_exist(indexName, column.name),
                    process_func=lambda: self._add_column(indexName, column),
                )
        except Exception as e:
            raise Exception(f"OBConnection.createIndex error: {str(e)}")
        finally:
            # always refresh metadata to make sure it contains the latest table structure
            self.client.refresh_metadata([indexName])

    def deleteIdx(self, indexName: str, knowledgebaseId: str):
        if len(knowledgebaseId) > 0:
            # The index need to be alive after any kb deletion since all kb under this tenant are in one index.
            return
        try:
            if self.client.check_table_exists(table_name=indexName):
                self.client.drop_table_if_exist(indexName)
                logger.info(f"Dropped table '{indexName}'.")
        except Exception as e:
            raise Exception(f"OBConnection.deleteIndex error: {str(e)}")

    def indexExist(self, indexName: str, knowledgebaseId: str = None) -> bool:
        try:
            if not self.client.check_table_exists(indexName):
                return False
            for column_name in index_columns:
                if not self._index_exists(indexName, index_name_template % (indexName, column_name)):
                    return False
            fts_columns = fts_columns_origin if self.search_original_content else fts_columns_tks
            for fts_column in fts_columns:
                column_name = fts_column.split("^")[0]
                if not self._index_exists(indexName, fulltext_index_name_template % column_name):
                    return False
            for column in [column_order_id, column_group_id]:
                if not self._column_exist(indexName, column.name):
                    return False
        except Exception as e:
            raise Exception(f"OBConnection.indexExist error: {str(e)}")

        return True

    def _get_count(self, table_name: str, filter_list: list[str] = None) -> int:
        where_clause = "WHERE " + " AND ".join(filter_list) if len(filter_list) > 0 else ""
        (count,) = self.client.perform_raw_text_sql(
            f"SELECT COUNT(*) FROM {table_name} {where_clause}"
        ).fetchone()
        return count

    def _column_exist(self, table_name: str, column_name: str) -> bool:
        return self._get_count(
            table_name="INFORMATION_SCHEMA.COLUMNS",
            filter_list=[
                f"TABLE_SCHEMA = '{self.db_name}'",
                f"TABLE_NAME = '{table_name}'",
                f"COLUMN_NAME = '{column_name}'",
            ]) > 0

    def _index_exists(self, table_name: str, index_name: str) -> bool:
        return self._get_count(
            table_name="INFORMATION_SCHEMA.STATISTICS",
            filter_list=[
                f"TABLE_SCHEMA = '{self.db_name}'",
                f"TABLE_NAME = '{table_name}'",
                f"INDEX_NAME = '{index_name}'",
            ]) > 0

    def _create_table(self, table_name: str):
        # remove outdated metadata for external changes
        if table_name in self.client.metadata_obj.tables:
            self.client.metadata_obj.remove(Table(table_name, self.client.metadata_obj))

        table_options = {
            "mysql_charset": "utf8mb4",
            "mysql_collate": "utf8mb4_unicode_ci",
            "mysql_organization": "heap",
        }

        self.client.create_table(
            table_name=table_name,
            columns=column_definitions,
            **table_options,
        )
        logger.info(f"Created table '{table_name}'.")

    def _add_index(self, table_name: str, column_name: str):
        index_name = index_name_template % (table_name, column_name)
        self.client.create_index(
            table_name=table_name,
            is_vec_index=False,
            index_name=index_name,
            column_names=[column_name],
        )
        logger.info(f"Created index '{index_name}' on table '{table_name}'.")

    def _add_fulltext_index(self, table_name: str, column_name: str):
        fulltext_index_name = fulltext_index_name_template % column_name
        self.client.create_fts_idx_with_fts_index_param(
            table_name=table_name,
            fts_idx_param=FtsIndexParam(
                index_name=fulltext_index_name,
                field_names=[column_name],
                parser_type=FtsParser.IK,
            ),
        )
        logger.info(f"Created full text index '{fulltext_index_name}' on table '{table_name}'.")

    def _add_vector_column(self, table_name: str, vector_size: int):
        vector_field_name = f"q_{vector_size}_vec"

        self.client.add_columns(
            table_name=table_name,
            columns=[Column(vector_field_name, VECTOR(vector_size), nullable=True)],
        )
        logger.info(f"Added vector column '{vector_field_name}' to table '{table_name}'.")

    def _add_vector_index(self, table_name: str, vector_field_name: str):
        vector_index_name = f"{vector_field_name}_idx"
        self.client.create_index(
            table_name=table_name,
            is_vec_index=True,
            index_name=vector_index_name,
            column_names=[vector_field_name],
            vidx_params="distance=cosine, type=hnsw, lib=vsag",
        )
        logger.info(
            f"Created vector index '{vector_index_name}' on table '{table_name}' with column '{vector_field_name}'."
        )

    def _add_column(self, table_name: str, column: Column):
        try:
            self.client.add_columns(
                table_name=table_name,
                columns=[column],
            )
            logger.info(f"Added column '{column.name}' to table '{table_name}'.")
        except Exception as e:
            logger.warning(f"Failed to add column '{column.name}' to table '{table_name}': {str(e)}")

    """
    CRUD operations
    """

    def search(
        self,
        selectFields: list[str],
        highlightFields: list[str],
        condition: dict,
        matchExprs: list[MatchExpr],
        orderBy: OrderByExpr,
        offset: int,
        limit: int,
        indexNames: str | list[str],
        knowledgebaseIds: list[str],
        aggFields: list[str] = [],
        rank_feature: dict | None = None,
        **kwargs,
    ):
        if isinstance(indexNames, str):
            indexNames = indexNames.split(",")
        assert isinstance(indexNames, list) and len(indexNames) > 0
        indexNames = list(set(indexNames))

        if len(matchExprs) == 3:
            if not self.enable_fulltext_search:
                # disable fulltext search in fusion search, which means fallback to vector search
                matchExprs = [m for m in matchExprs if isinstance(m, MatchDenseExpr)]
            else:
                for m in matchExprs:
                    if isinstance(m, FusionExpr):
                        weights = m.fusion_params["weights"]
                        vector_similarity_weight = get_float(weights.split(",")[1])
                        # skip the search if its weight is zero
                        if vector_similarity_weight <= 0.0:
                            matchExprs = [m for m in matchExprs if isinstance(m, MatchTextExpr)]
                        elif vector_similarity_weight >= 1.0:
                            matchExprs = [m for m in matchExprs if isinstance(m, MatchDenseExpr)]

        result: SearchResult = SearchResult(
            total=0,
            chunks=[],
        )

        # copied from es_conn.py
        if len(matchExprs) == 3 and self.es:
            bqry = Q("bool", must=[])
            condition["kb_id"] = knowledgebaseIds
            for k, v in condition.items():
                if k == "available_int":
                    if v == 0:
                        bqry.filter.append(Q("range", available_int={"lt": 1}))
                    else:
                        bqry.filter.append(
                            Q("bool", must_not=Q("range", available_int={"lt": 1})))
                    continue
                if not v:
                    continue
                if isinstance(v, list):
                    bqry.filter.append(Q("terms", **{k: v}))
                elif isinstance(v, str) or isinstance(v, int):
                    bqry.filter.append(Q("term", **{k: v}))
                else:
                    raise Exception(
                        f"Condition `{str(k)}={str(v)}` value type is {str(type(v))}, expected to be int, str or list.")

            s = Search()
            vector_similarity_weight = 0.5
            for m in matchExprs:
                if isinstance(m, FusionExpr) and m.method == "weighted_sum" and "weights" in m.fusion_params:
                    assert len(matchExprs) == 3 and isinstance(matchExprs[0], MatchTextExpr) and isinstance(
                        matchExprs[1],
                        MatchDenseExpr) and isinstance(
                        matchExprs[2], FusionExpr)
                    weights = m.fusion_params["weights"]
                    vector_similarity_weight = get_float(weights.split(",")[1])
            for m in matchExprs:
                if isinstance(m, MatchTextExpr):
                    minimum_should_match = m.extra_options.get("minimum_should_match", 0.0)
                    if isinstance(minimum_should_match, float):
                        minimum_should_match = str(int(minimum_should_match * 100)) + "%"
                    bqry.must.append(Q("query_string", fields=fts_columns_tks,
                                       type="best_fields", query=m.matching_text,
                                       minimum_should_match=minimum_should_match,
                                       boost=1))
                    bqry.boost = 1.0 - vector_similarity_weight

                elif isinstance(m, MatchDenseExpr):
                    assert (bqry is not None)
                    similarity = 0.0
                    if "similarity" in m.extra_options:
                        similarity = m.extra_options["similarity"]
                    s = s.knn(m.vector_column_name,
                              m.topn,
                              m.topn * 2,
                              query_vector=list(m.embedding_data),
                              filter=bqry.to_dict(),
                              similarity=similarity,
                              )

            if bqry and rank_feature:
                for fld, sc in rank_feature.items():
                    if fld != PAGERANK_FLD:
                        fld = f"{TAG_FLD}.{fld}"
                    bqry.should.append(Q("rank_feature", field=fld, linear={}, boost=sc))

            if bqry:
                s = s.query(bqry)
            # for field in highlightFields:
            #     s = s.highlight(field)

            if orderBy:
                orders = list()
                for field, order in orderBy.fields:
                    order = "asc" if order == 0 else "desc"
                    if field in ["page_num_int", "top_int"]:
                        order_info = {"order": order, "unmapped_type": "float",
                                      "mode": "avg", "numeric_type": "double"}
                    elif field.endswith("_int") or field.endswith("_flt"):
                        order_info = {"order": order, "unmapped_type": "float"}
                    else:
                        order_info = {"order": order, "unmapped_type": "text"}
                    orders.append({field: order_info})
                s = s.sort(*orders)

            for fld in aggFields:
                s.aggs.bucket(f'aggs_{fld}', 'terms', field=fld, size=1000000)

            if limit > 0:
                s = s[offset:offset + limit]
            q = s.to_dict()
            logger.debug(f"OBConnection.hybrid_search {str(indexNames)} query: " + json.dumps(q))

            for index_name in indexNames:
                start_time = time.time()
                res = self.es.search(index=index_name,
                                     body=q,
                                     timeout="600s",
                                     track_total_hits=True,
                                     _source=True)
                elapsed_time = time.time() - start_time
                logger.info(
                    f"OBConnection.search table {index_name}, search type: hybrid, elapsed time: {elapsed_time:.3f} seconds,"
                    f" got count: {len(res)}"
                )
                for chunk in res:
                    result.chunks.append(self._es_row_to_entity(chunk))
                    result.total = result.total + 1
            return result

        output_fields = selectFields.copy()
        if "id" not in output_fields:
            output_fields = ["id"] + output_fields
        if "_score" in output_fields:
            output_fields.remove("_score")

        if highlightFields:
            for field in highlightFields:
                if field not in output_fields:
                    output_fields.append(field)

        fields_expr = ", ".join(output_fields)

        condition["kb_id"] = knowledgebaseIds
        filters: list[str] = get_filters(condition)
        filters_expr = " AND ".join(filters)

        fulltext_query: Optional[str] = None
        fulltext_topn: Optional[int] = None
        fulltext_search_weight: dict[str, float] = {}
        fulltext_search_expr: dict[str, str] = {}
        fulltext_search_idx_list: list[str] = []
        fulltext_search_score_expr: Optional[str] = None
        fulltext_search_filter: Optional[str] = None

        vector_column_name: Optional[str] = None
        vector_data: Optional[list[float]] = None
        vector_topn: Optional[int] = None
        vector_similarity_threshold: Optional[float] = None
        vector_similarity_weight: Optional[float] = None
        vector_search_expr: Optional[str] = None
        vector_search_score_expr: Optional[str] = None
        vector_search_filter: Optional[str] = None

        for m in matchExprs:
            if isinstance(m, MatchTextExpr):
                assert "original_query" in m.extra_options, "'original_query' is missing in extra_options."
                fulltext_query = m.extra_options["original_query"]
                fulltext_query = escape_string(fulltext_query.strip())
                fulltext_topn = m.topn

                fts_columns = fts_columns_origin if self.search_original_content else fts_columns_tks

                # get fulltext match expression and weight values
                for field in fts_columns:
                    parts = field.split("^")
                    column_name: str = parts[0]
                    column_weight: float = float(parts[1]) if (len(parts) > 1 and parts[1]) else 1.0

                    fulltext_search_weight[column_name] = column_weight
                    fulltext_search_expr[column_name] = fulltext_search_template % (column_name, fulltext_query)
                    fulltext_search_idx_list.append(fulltext_index_name_template % column_name)

                # adjust the weight to 0~1
                weight_sum = sum(fulltext_search_weight.values())
                for column_name in fulltext_search_weight.keys():
                    fulltext_search_weight[column_name] = fulltext_search_weight[column_name] / weight_sum

            elif isinstance(m, MatchDenseExpr):
                assert m.embedding_data_type == "float", f"embedding data type '{m.embedding_data_type}' is not float."
                vector_column_name = m.vector_column_name
                vector_data = m.embedding_data
                vector_topn = m.topn
                vector_similarity_threshold = m.extra_options.get("similarity", 0.0)
            elif isinstance(m, FusionExpr):
                weights = m.fusion_params["weights"]
                vector_similarity_weight = get_float(weights.split(",")[1])

        if fulltext_query:
            fulltext_search_filter = f"({' OR '.join([expr for expr in fulltext_search_expr.values()])})"
            fulltext_search_score_expr = f"({' + '.join(f'{expr} * {fulltext_search_weight.get(col, 0)}' for col, expr in fulltext_search_expr.items())})"

        if vector_data:
            vector_search_expr = vector_search_template % (vector_column_name, vector_data)
            # use (1 - cosine_distance) as score, which should be [-1, 1]
            # https://www.oceanbase.com/docs/common-oceanbase-database-standalone-1000000003577323
            vector_search_score_expr = f"(1 - {vector_search_expr})"
            vector_search_filter = f"{vector_search_score_expr} >= {vector_similarity_threshold}"

        pagerank_score_expr = f"(CAST(IFNULL({PAGERANK_FLD}, 0) AS DECIMAL(10, 2)) / 100)"

        # TODO use tag rank_feature in sorting
        # tag_rank_fea = {k: float(v) for k, v in (rank_feature or {}).items() if k != PAGERANK_FLD}

        if fulltext_query and vector_data:
            search_type = "fusion"
        elif fulltext_query:
            search_type = "fulltext"
        elif vector_data:
            search_type = "vector"
        elif len(aggFields) > 0:
            search_type = "aggregation"
        else:
            search_type = "filter"

        if search_type in ["fusion", "fulltext", "vector"] and "_score" not in output_fields:
            output_fields.append("_score")

        group_results = kwargs.get("group_results", False)

        for index_name in indexNames:

            if not self.client.check_table_exists(index_name):
                continue

            fulltext_search_hint = f"/*+ UNION_MERGE({index_name} {' '.join(fulltext_search_idx_list)}) */" if self.use_fulltext_hint else ""

            if search_type == "fusion":
                # fusion search, usually for chat
                num_candidates = vector_topn + fulltext_topn
                if group_results:
                    count_sql = (
                        f"WITH fulltext_results AS ("
                        f"  SELECT {fulltext_search_hint} *, {fulltext_search_score_expr} AS relevance"
                        f"      FROM {index_name}"
                        f"      WHERE {filters_expr} AND {fulltext_search_filter}"
                        f"      ORDER BY relevance DESC"
                        f"      LIMIT {num_candidates}"
                        f"),"
                        f" scored_results AS ("
                        f"  SELECT *"
                        f"      FROM fulltext_results"
                        f"      WHERE {vector_search_filter}"
                        f"),"
                        f" group_results AS ("
                        f"  SELECT *, ROW_NUMBER() OVER (PARTITION BY group_id) as rn"
                        f"      FROM scored_results"
                        f")"
                        f"  SELECT COUNT(*)"
                        f"      FROM group_results"
                        f"      WHERE rn = 1"
                    )
                else:
                    count_sql = (
                        f"WITH fulltext_results AS ("
                        f"  SELECT {fulltext_search_hint} *, {fulltext_search_score_expr} AS relevance"
                        f"      FROM {index_name}"
                        f"      WHERE {filters_expr} AND {fulltext_search_filter}"
                        f"      ORDER BY relevance DESC"
                        f"      LIMIT {num_candidates}"
                        f")"
                        f"  SELECT COUNT(*) FROM fulltext_results WHERE {vector_search_filter}"
                    )
                logger.debug("OBConnection.search with count sql: %s", count_sql)

                start_time = time.time()

                res = self.client.perform_raw_text_sql(count_sql)
                total_count = res.fetchone()[0] if res else 0
                result.total += total_count

                elapsed_time = time.time() - start_time
                logger.info(
                    f"OBConnection.search table {index_name}, search type: fusion, step: 1-count, elapsed time: {elapsed_time:.3f} seconds,"
                    f" vector column: '{vector_column_name}',"
                    f" query text: '{fulltext_query}',"
                    f" condition: '{condition}',"
                    f" vector_similarity_threshold: {vector_similarity_threshold},"
                    f" got count: {total_count}"
                )

                if total_count == 0:
                    continue

                score_expr = f"(relevance * {1 - vector_similarity_weight} + {vector_search_score_expr} * {vector_similarity_weight} + {pagerank_score_expr})"
                if group_results:
                    fusion_sql = (
                        f"WITH fulltext_results AS ("
                        f"  SELECT {fulltext_search_hint} *, {fulltext_search_score_expr} AS relevance"
                        f"      FROM {index_name}"
                        f"      WHERE {filters_expr} AND {fulltext_search_filter}"
                        f"      ORDER BY relevance DESC"
                        f"      LIMIT {num_candidates}"
                        f"),"
                        f" scored_results AS ("
                        f"  SELECT *, {score_expr} AS _score"
                        f"      FROM fulltext_results"
                        f"      WHERE {vector_search_filter}"
                        f"),"
                        f" group_results AS ("
                        f"  SELECT *, ROW_NUMBER() OVER (PARTITION BY group_id ORDER BY _score DESC) as rn"
                        f"      FROM scored_results"
                        f")"
                        f"  SELECT {fields_expr}, _score"
                        f"      FROM group_results"
                        f"      WHERE rn = 1"
                        f"      ORDER BY _score DESC"
                        f"      LIMIT {offset}, {limit}"
                    )
                else:
                    fusion_sql = (
                        f"WITH fulltext_results AS ("
                        f"  SELECT {fulltext_search_hint} *, {fulltext_search_score_expr} AS relevance"
                        f"      FROM {index_name}"
                        f"      WHERE {filters_expr} AND {fulltext_search_filter}"
                        f"      ORDER BY relevance DESC"
                        f"      LIMIT {num_candidates}"
                        f")"
                        f"  SELECT {fields_expr}, {score_expr} AS _score"
                        f"      FROM fulltext_results"
                        f"      WHERE {vector_search_filter}"
                        f"      ORDER BY _score DESC"
                        f"      LIMIT {offset}, {limit}"
                    )
                logger.debug("OBConnection.search with fusion sql: %s", fusion_sql)

                start_time = time.time()

                res = self.client.perform_raw_text_sql(fusion_sql)
                rows = res.fetchall()

                elapsed_time = time.time() - start_time
                logger.info(
                    f"OBConnection.search table {index_name}, search type: fusion, step: 2-query, elapsed time: {elapsed_time:.3f} seconds,"
                    f" select fields: '{output_fields}',"
                    f" vector column: '{vector_column_name}',"
                    f" query text: '{fulltext_query}',"
                    f" condition: '{condition}',"
                    f" vector_similarity_threshold: {vector_similarity_threshold},"
                    f" vector_similarity_weight: {vector_similarity_weight},"
                    f" return rows count: {len(rows)}"
                )

                for row in rows:
                    result.chunks.append(self._row_to_entity(row, output_fields))
            elif search_type == "vector":
                # vector search, usually used for graph search
                count_sql = f"SELECT COUNT(id) FROM {index_name} WHERE {filters_expr} AND {vector_search_filter}"
                logger.debug("OBConnection.search with vector count sql: %s", count_sql)

                start_time = time.time()

                res = self.client.perform_raw_text_sql(count_sql)
                total_count = res.fetchone()[0] if res else 0
                result.total += total_count

                elapsed_time = time.time() - start_time
                logger.info(
                    f"OBConnection.search table {index_name}, search type: vector, step: 1-count, elapsed time: {elapsed_time:.3f} seconds,"
                    f" vector column: '{vector_column_name}',"
                    f" condition: '{condition}',"
                    f" vector_similarity_threshold: {vector_similarity_threshold},"
                    f" got count: {total_count}"
                )

                if total_count == 0:
                    continue

                vector_sql = (
                    f"SELECT {fields_expr}, {vector_search_score_expr} AS _score"
                    f"  FROM {index_name}"
                    f"  WHERE {filters_expr} AND {vector_search_filter}"
                    f"  ORDER BY {vector_search_expr}"
                    f"  APPROXIMATE LIMIT {limit if limit != 0 else vector_topn}"
                )
                if offset != 0:
                    vector_sql += f" OFFSET {offset}"
                logger.debug("OBConnection.search with vector sql: %s", vector_sql)

                start_time = time.time()

                res = self.client.perform_raw_text_sql(vector_sql)
                rows = res.fetchall()

                elapsed_time = time.time() - start_time
                logger.info(
                    f"OBConnection.search table {index_name}, search type: vector, step: 2-query, elapsed time: {elapsed_time:.3f} seconds,"
                    f" select fields: '{output_fields}',"
                    f" vector column: '{vector_column_name}',"
                    f" condition: '{condition}',"
                    f" vector_similarity_threshold: {vector_similarity_threshold},"
                    f" return rows count: {len(rows)}"
                )

                for row in rows:
                    result.chunks.append(self._row_to_entity(row, output_fields))
            elif search_type == "fulltext":
                # fulltext search, usually used to search chunks in one dataset
                count_sql = f"SELECT {fulltext_search_hint} COUNT(id) FROM {index_name} WHERE {filters_expr} AND {fulltext_search_filter}"
                logger.debug("OBConnection.search with fulltext count sql: %s", count_sql)

                start_time = time.time()

                res = self.client.perform_raw_text_sql(count_sql)
                total_count = res.fetchone()[0] if res else 0
                result.total += total_count

                elapsed_time = time.time() - start_time
                logger.info(
                    f"OBConnection.search table {index_name}, search type: fulltext, step: 1-count, elapsed time: {elapsed_time:.3f} seconds,"
                    f" query text: '{fulltext_query}',"
                    f" condition: '{condition}',"
                    f" got count: {total_count}"
                )

                if total_count == 0:
                    continue

                fulltext_sql = (
                    f"SELECT {fulltext_search_hint} {fields_expr}, {fulltext_search_score_expr} AS _score"
                    f"  FROM {index_name}"
                    f"  WHERE {filters_expr} AND {fulltext_search_filter}"
                    f"  ORDER BY _score DESC"
                    f"  LIMIT {offset}, {limit if limit != 0 else fulltext_topn}"
                )
                logger.debug("OBConnection.search with fulltext sql: %s", fulltext_sql)

                start_time = time.time()

                res = self.client.perform_raw_text_sql(fulltext_sql)
                rows = res.fetchall()

                elapsed_time = time.time() - start_time
                logger.info(
                    f"OBConnection.search table {index_name}, search type: fulltext, step: 2-query, elapsed time: {elapsed_time:.3f} seconds,"
                    f" select fields: '{output_fields}',"
                    f" query text: '{fulltext_query}',"
                    f" condition: '{condition}',"
                    f" return rows count: {len(rows)}"
                )

                for row in rows:
                    result.chunks.append(self._row_to_entity(row, output_fields))
            elif search_type == "aggregation":
                # aggregation search
                assert len(aggFields) == 1, "Only one aggregation field is supported in OceanBase."
                agg_field = aggFields[0]
                if agg_field in array_columns:
                    res = self.client.perform_raw_text_sql(
                        f"SELECT {agg_field} FROM {index_name}"
                        f" WHERE {agg_field} IS NOT NULL AND {filters_expr}"
                    )
                    counts = {}
                    for row in res:
                        if row[0]:
                            if isinstance(row[0], str):
                                try:
                                    arr = json.loads(row[0])
                                except json.JSONDecodeError:
                                    logger.warning(f"Failed to parse JSON array: {row[0]}")
                                    continue
                            else:
                                arr = row[0]

                            if isinstance(arr, list):
                                for v in arr:
                                    if isinstance(v, str) and v.strip():
                                        counts[v] = counts.get(v, 0) + 1

                    for v, count in counts.items():
                        result.chunks.append({
                            "value": v,
                            "count": count,
                        })
                    result.total += len(counts)
                else:
                    res = self.client.perform_raw_text_sql(
                        f"SELECT {agg_field}, COUNT(*) as count FROM {index_name}"
                        f" WHERE {agg_field} IS NOT NULL AND {filters_expr}"
                        f" GROUP BY {agg_field}"
                    )
                    for row in res:
                        result.chunks.append({
                            "value": row[0],
                            "count": int(row[1]),
                        })
                        result.total += 1
            else:
                # only filter
                orders: list[str] = []
                if orderBy:
                    for field, order in orderBy.fields:
                        if isinstance(column_types[field], ARRAY):
                            f = field + "_sort"
                            fields_expr += f", array_to_string({field}, ',') AS {f}"
                            field = f
                        order = "ASC" if order == 0 else "DESC"
                        orders.append(f"{field} {order}")
                count_sql = f"SELECT COUNT(id) FROM {index_name} WHERE {filters_expr}"
                logger.debug("OBConnection.search with normal count sql: %s", count_sql)

                start_time = time.time()

                res = self.client.perform_raw_text_sql(count_sql)
                total_count = res.fetchone()[0] if res else 0
                result.total += total_count

                elapsed_time = time.time() - start_time
                logger.info(
                    f"OBConnection.search table {index_name}, search type: normal, step: 1-count, elapsed time: {elapsed_time:.3f} seconds,"
                    f" condition: '{condition}',"
                    f" got count: {total_count}"
                )

                if total_count == 0:
                    continue

                order_by_expr = ("ORDER BY " + ", ".join(orders)) if len(orders) > 0 else ""
                limit_expr = f"LIMIT {offset}, {limit}" if limit != 0 else ""
                filter_sql = (
                    f"SELECT {fields_expr}"
                    f"  FROM {index_name}"
                    f"  WHERE {filters_expr}"
                    f"  {order_by_expr} {limit_expr}"
                )
                logger.debug("OBConnection.search with normal sql: %s", filter_sql)

                start_time = time.time()

                res = self.client.perform_raw_text_sql(filter_sql)
                rows = res.fetchall()

                elapsed_time = time.time() - start_time
                logger.info(
                    f"OBConnection.search table {index_name}, search type: normal, step: 2-query, elapsed time: {elapsed_time:.3f} seconds,"
                    f" select fields: '{output_fields}',"
                    f" condition: '{condition}',"
                    f" return rows count: {len(rows)}"
                )

                for row in rows:
                    result.chunks.append(self._row_to_entity(row, output_fields))
        return result

    def get(self, chunkId: str, indexName: str, knowledgebaseIds: list[str]) -> dict | None:
        if not self.client.check_table_exists(indexName):
            return None

        try:
            res = self.client.get(
                table_name=indexName,
                ids=[chunkId],
            )
            row = res.fetchone()
            if row is None:
                raise Exception(f"ChunkId {chunkId} not found in index {indexName}.")

            return self._row_to_entity(row, fields=list(res.keys()))
        except json.JSONDecodeError as e:
            logger.error(f"JSON decode error when getting chunk {chunkId}: {str(e)}")
            return {
                "id": chunkId,
                "error": f"Failed to parse chunk data due to invalid JSON: {str(e)}"
            }
        except Exception as e:
            logger.error(f"Error getting chunk {chunkId}: {str(e)}")
            raise

    def insert(self, documents: list[dict], indexName: str, knowledgebaseId: str = None) -> list[str]:
        if not documents:
            return []

        docs: list[dict] = []
        ids: list[str] = []
        for document in documents:
            d: dict = {}
            for k, v in document.items():
                if vector_column_pattern.match(k):
                    d[k] = v
                    continue
                if k not in column_names:
                    if "extra" not in d:
                        d["extra"] = {}
                    d["extra"][k] = v
                    continue
                if v is None:
                    d[k] = get_default_value(k)
                    continue

                if k == "kb_id" and isinstance(v, list):
                    d[k] = v[0]
                elif k == "content_with_weight" and isinstance(v, dict):
                    d[k] = json.dumps(v, ensure_ascii=False)
                elif k == "position_int":
                    d[k] = json.dumps([list(vv) for vv in v], ensure_ascii=False)
                elif isinstance(v, list):
                    # remove characters like '\t' for JSON dump and clean special characters
                    cleaned_v = []
                    for vv in v:
                        if isinstance(vv, str):
                            cleaned_str = vv.strip()
                            cleaned_str = cleaned_str.replace('\\', '\\\\')
                            cleaned_str = cleaned_str.replace('\n', '\\n')
                            cleaned_str = cleaned_str.replace('\r', '\\r')
                            cleaned_str = cleaned_str.replace('\t', '\\t')
                            cleaned_v.append(cleaned_str)
                        else:
                            cleaned_v.append(vv)
                    d[k] = json.dumps(cleaned_v, ensure_ascii=False)
                else:
                    d[k] = v

            ids.append(d["id"])
            # this is to fix https://github.com/sqlalchemy/sqlalchemy/issues/9703
            for column_name in column_names:
                if column_name not in d:
                    d[column_name] = get_default_value(column_name)

            metadata = d.get("metadata", {})
            if metadata is None:
                metadata = {}
            group_id = metadata.get("_group_id")
            title = metadata.get("_title")
            if d.get("doc_id"):
                if group_id:
                    d["group_id"] = group_id
                else:
                    d["group_id"] = d["doc_id"]
                if title:
                    d["docnm_kwd"] = title

            docs.append(d)

        logger.debug("OBConnection.insert chunks: %s", docs)

        res = []
        try:
            self.client.upsert(indexName, docs)
        except Exception as e:
            logger.error(f"OBConnection.insert error: {str(e)}")
            res.append(str(e))
        return res

    def update(self, condition: dict, newValue: dict, indexName: str, knowledgebaseId: str) -> bool:
        if not self.client.check_table_exists(indexName):
            return True

        condition["kb_id"] = knowledgebaseId
        filters = get_filters(condition)
        set_values: list[str] = []
        for k, v in newValue.items():
            if k == "remove":
                if isinstance(v, str):
                    set_values.append(f"{v} = NULL")
                else:
                    assert isinstance(v, dict), f"Expected str or dict for 'remove', got {type(newValue[k])}."
                    for kk, vv in v.items():
                        assert kk in array_columns, f"Column '{kk}' is not an array column."
                        set_values.append(f"{kk} = array_remove({kk}, {get_value_str(vv)})")
            elif k == "add":
                assert isinstance(v, dict), f"Expected str or dict for 'add', got {type(newValue[k])}."
                for kk, vv in v.items():
                    assert kk in array_columns, f"Column '{kk}' is not an array column."
                    set_values.append(f"{kk} = array_append({kk}, {get_value_str(vv)})")
            elif k == "metadata":
                assert isinstance(v, dict), f"Expected dict for 'metadata', got {type(newValue[k])}"
                set_values.append(f"{k} = {get_value_str(v)}")
                if v and "doc_id" in condition:
                    group_id = v.get("_group_id")
                    title = v.get("_title")
                    if group_id:
                        set_values.append(f"group_id = {get_value_str(group_id)}")
                    if title:
                        set_values.append(f"docnm_kwd = {get_value_str(title)}")
            else:
                set_values.append(f"{k} = {get_value_str(v)}")

        if not set_values:
            return True

        update_sql = (
            f"UPDATE {indexName}"
            f" SET {', '.join(set_values)}"
            f" WHERE {' AND '.join(filters)}"
        )
        logger.debug("OBConnection.update sql: %s", update_sql)

        try:
            self.client.perform_raw_text_sql(update_sql)
            return True
        except Exception as e:
            logger.error(f"OBConnection.update error: {str(e)}")
        return False

    def delete(self, condition: dict, indexName: str, knowledgebaseId: str) -> int:
        if not self.client.check_table_exists(indexName):
            return 0

        condition["kb_id"] = knowledgebaseId
        try:
            res = self.client.get(
                table_name=indexName,
                ids=None,
                where_clause=[text(f) for f in get_filters(condition)],
                output_column_name=["id"],
            )
            rows = res.fetchall()
            if len(rows) == 0:
                return 0
            ids = [row[0] for row in rows]
            logger.debug(f"OBConnection.delete chunks, filters: {condition}, ids: {ids}")
            self.client.delete(
                table_name=indexName,
                ids=ids,
            )
            return len(ids)
        except Exception as e:
            logger.error(f"OBConnection.delete error: {str(e)}")
        return 0

    @staticmethod
    def _row_to_entity(data: Row, fields: list[str]) -> dict:
        entity = {}
        for i, field in enumerate(fields):
            value = data[i]
            if value is None:
                continue
            entity[field] = get_column_value(field, value)
        return entity

    @staticmethod
    def _es_row_to_entity(data: dict) -> dict:
        entity = {}
        for k, v in data.items():
            if v is None:
                continue
            entity[k] = get_column_value(k, v)
        return entity

    """
    Helper functions for search result
    """

    def get_total(self, res) -> int:
        return res.total

    def get_chunk_ids(self, res) -> list[str]:
        return [row["id"] for row in res.chunks]

    def get_fields(self, res, fields: list[str]) -> dict[str, dict]:
        result = {}
        for row in res.chunks:
            data = {}
            for field in fields:
                v = row.get(field)
                if v is not None:
                    data[field] = v
            result[row["id"]] = data
        return result

    # copied from query.FulltextQueryer
    def is_chinese(self, line):
        arr = re.split(r"[ \t]+", line)
        if len(arr) <= 3:
            return True
        e = 0
        for t in arr:
            if not re.match(r"[a-zA-Z]+$", t):
                e += 1
        return e * 1.0 / len(arr) >= 0.7

    def highlight(self, txt: str, tks: str, question: str, keywords: list[str]) -> Optional[str]:
        if not txt or not keywords:
            return None

        highlighted_txt = txt

        if question and not self.is_chinese(question):
            highlighted_txt = re.sub(
                r"(^|\W)(%s)(\W|$)" % re.escape(question),
                r"\1<em>\2</em>\3", highlighted_txt,
                flags=re.IGNORECASE | re.MULTILINE,
            )
            if re.search(r"<em>[^<>]+</em>", highlighted_txt, flags=re.IGNORECASE | re.MULTILINE):
                return highlighted_txt

            for keyword in keywords:
                highlighted_txt = re.sub(
                    r"(^|\W)(%s)(\W|$)" % re.escape(keyword),
                    r"\1<em>\2</em>\3", highlighted_txt,
                    flags=re.IGNORECASE | re.MULTILINE,
                )
            if len(re.findall(r'</em><em>', highlighted_txt)) > 0 or len(
                re.findall(r'</em>\s*<em>', highlighted_txt)) > 0:
                return highlighted_txt
            else:
                return None

        if not tks:
            tks = rag_tokenizer.tokenize(txt)
        tokens = tks.split()
        if not tokens:
            return None

        last_pos = len(txt)

        for i in range(len(tokens) - 1, -1, -1):
            token = tokens[i]
            token_pos = highlighted_txt.rfind(token, 0, last_pos)
            if token_pos != -1:
                if token in keywords:
                    highlighted_txt = (
                        highlighted_txt[:token_pos] +
                        f'<em>{token}</em>' +
                        highlighted_txt[token_pos + len(token):]
                    )
                last_pos = token_pos
        return re.sub(r'</em><em>', '', highlighted_txt)

    def get_highlight(self, res, keywords: list[str], fieldnm: str):
        ans = {}
        if len(res.chunks) == 0 or len(keywords) == 0:
            return ans

        for d in res.chunks:
            txt = d.get(fieldnm)
            if not txt:
                continue

            tks = d.get("content_ltks") if fieldnm == "content_with_weight" else ""
            highlighted_txt = self.highlight(txt, tks, " ".join(keywords), keywords)
            if highlighted_txt:
                ans[d["id"]] = highlighted_txt
        return ans

    def get_aggregation(self, res, fieldnm: str):
        if len(res.chunks) == 0:
            return []

        counts = {}
        result = []
        for d in res.chunks:
            if "value" in d and "count" in d:
                # directly use the aggregation result
                result.append((d["value"], d["count"]))
            elif fieldnm in d:
                # aggregate the values of specific field
                v = d[fieldnm]
                if isinstance(v, list):
                    for vv in v:
                        if isinstance(vv, str) and vv.strip():
                            counts[vv] = counts.get(vv, 0) + 1
                elif isinstance(v, str) and v.strip():
                    counts[v] = counts.get(v, 0) + 1

        if len(counts) > 0:
            for k, v in counts.items():
                result.append((k, v))

        return result

    """
    SQL
    """

    def sql(sql: str, fetch_size: int, format: str):
        # TODO: execute the sql generated by text-to-sql
        return None
